
feature(pdd): enable P/D disaggregation with NIXL host KV transfer#477

Open
rebel-ykchoi wants to merge 4 commits into dev from feat_pd_disag

Conversation

Contributor

rebel-ykchoi commented Mar 24, 2026

🚀 Summary of Changes

Wire vLLM KV transfer to an RBLN-specific NIXL connector and host-side buffers so prefill and decode can run on separate engines with host-to-host (H2H) KV transfer.

KV connector / registration

  • Add RblnNixlConnector (scheduler/worker) extending the upstream NixlConnector.
  • Register the connector name "RblnNixlConnector" in the kv_connector factory.

Platform

  • Expose NIXL hints: get_nixl_supported_devices (maps rbln -> cpu) and get_nixl_memory_type (returns "DRAM").
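A minimal sketch of these two platform hints, assuming a classmethod-style platform interface (the surrounding class and its base are assumptions, not the actual vllm_rbln code):

```python
# Hedged sketch of the platform-level NIXL hints named above.
class RblnPlatformSketch:
    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        # Map the "rbln" device type to the NIXL device kinds it can use
        # for KV transfer; host-side buffers mean "cpu" here.
        return {"rbln": ("cpu",)}

    @classmethod
    def get_nixl_memory_type(cls) -> str:
        # KV blocks live in host DRAM, so NIXL should register DRAM
        # memory regions rather than device memory.
        return "DRAM"
```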

Scheduler (rbln_scheduler.py)

  • Allow kv_consumer requests to be scheduled together with other requests in the decode stage.

Model runner (rbln_model_runner.py)

  • Override maybe_get_kv_connector_output(..., wait_for_save) using the last prefill chunk.
  • Replace the generic copy_kv_blocks with rbln_copy_kv_blocks, which uses the runtime's _update_kv_cache / _fetch_kv_cache.
  • Add bind_kv_cache_name plus per-layer names so mark_static_address can be applied when compiling.

Attention backend (flash_attention.py)

  • Report the backend name as FLASH_ATTN for upstream compatibility.

Tests

  • tests/torch_compile/e2e/v1/kv_connector/nixl_integration/run_accuracy_test.sh

What does this PR do? What feature, fix, or improvement does it bring?


📌 Related Issues / Tickets

  • Resolves #
  • Related to #

✅ Type of Change

  • 🚀 Release (release)
  • ✨ Feature (feature)
  • 🧠 Model support (model)
  • 🧬 Core engine changes (core)
  • 🛠 Bug fix (fix)
  • ⚙️ Performance improvement (perf)
  • 🔁 Refactor or code cleanup (refactor)
  • 📄 Documentation (docs)
  • ❓ Other (other): please describe

🧪 How to Test

  1. Run ...
  2. Verify output: ...
  3. Edge case tested: ...

📸 Screenshots / Logs (if applicable)


📋 Checklist

  • PR title follows Conventional Commits format
  • This PR is linked to an existing issue
  • The test method is described, and the expected result is clearly stated
  • Relevant documentation has been updated (if applicable)

💬 Notes


Examples

  • Add the experimental examples/experimental/pd_disaggregation/toy_proxy_server.py (a FastAPI proxy routing chat completions to prefill vs decode HTTP backends).
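The proxy's backend selection can be sketched as a simple round-robin over two pools, one for prefill and one for decode. This is an illustrative stand-in, not the actual toy_proxy_server.py code; the URLs and ports below are assumptions:

```python
# Hedged sketch: a P/D proxy keeps separate prefill and decode backend
# pools and picks the next backend round-robin from the relevant pool.
from itertools import cycle

class BackendPool:
    def __init__(self, prefill_urls: list[str], decode_urls: list[str]):
        # cycle() yields the URLs repeatedly in order.
        self._prefill = cycle(prefill_urls)
        self._decode = cycle(decode_urls)

    def next_prefill(self) -> str:
        return next(self._prefill)

    def next_decode(self) -> str:
        return next(self._decode)

# Hypothetical topology: one prefill instance, two decode instances.
pool = BackendPool(
    prefill_urls=["http://localhost:8100/v1"],
    decode_urls=["http://localhost:8200/v1", "http://localhost:8201/v1"],
)
```

In the real setup the prefill request runs to completion first (producing the KV cache transferred over NIXL), and only the decode backend streams tokens back to the client.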
rebel-ykchoi changed the title from "feature: enable P/D disaggregation with NIXL host KV transfer" to "feature(pdd): enable P/D disaggregation with NIXL host KV transfer" on Apr 15, 2026
@rebel-ykchoi rebel-ykchoi marked this pull request as ready for review April 15, 2026 10:32
@rebel-jaehwang rebel-jaehwang requested a review from Copilot April 15, 2026 10:36

Copilot AI left a comment


Pull request overview

This PR wires vLLM KV transfer to an RBLN-specific NIXL connector and host-side buffers to enable prefill/decode disaggregation (with H2H KV transfer), plus related scheduler/model-runner/attention-backend integration and an E2E accuracy test harness.

Changes:

  • Add RblnNixlConnector and register it via a connector factory import hook.
  • Update worker/model-runner/scheduler and FlashAttention metadata plumbing to support P/D disaggregation and KV connector lifecycle.
  • Add an end-to-end NIXL integration test setup (proxy server + lm-eval accuracy test script).

Reviewed changes

Copilot reviewed 12 out of 18 changed files in this pull request and generated 4 comments.

Summary per file:

  • vllm_rbln/v1/worker/utils.py: Add bind_kv_cache_name helper to name KV cache buffers for compilation/static-address marking.
  • vllm_rbln/v1/worker/rbln_worker.py: Move KV-transfer init to initialize_from_config, add a handshake-metadata helper, add a shutdown hook.
  • vllm_rbln/v1/worker/rbln_model_runner.py: Integrate KV connector output handling, preemption handling, host-buffer KV copy ops, and static KV cache naming.
  • vllm_rbln/v1/core/rbln_scheduler.py: Adjust the scheduling flow so remote-KV consumers can be batched with decode requests without mixing in local prefill.
  • vllm_rbln/v1/attention/backends/flash_attention.py: Switch prefill/decode detection to an explicit is_prefill param; report the backend name as FLASH_ATTN.
  • vllm_rbln/platform.py: Expose NIXL device/memory hints for the RBLN platform.
  • vllm_rbln/distributed/kv_transfer/kv_connector/v1/rbln_nixl_connector.py: Implement RBLN-specific NIXL connector scheduler/worker behavior and host transfer buffers.
  • vllm_rbln/distributed/kv_transfer/kv_connector/factory.py: Register the connector name RblnNixlConnector in the upstream factory.
  • vllm_rbln/__init__.py: Ensure connector factory registration runs when ops are registered.
  • tests/torch_compile/e2e/v1/kv_connector/nixl_integration/toy_proxy_server.py: Add a proxy server that routes prefill to prefiller instances and streaming decode to decoder instances.
  • tests/torch_compile/e2e/v1/kv_connector/nixl_integration/test_accuracy.py: Add an lm-eval-based accuracy test for the disaggregated setup.
  • tests/torch_compile/e2e/v1/kv_connector/nixl_integration/run_accuracy_test.sh: Add a runner script that launches prefill/decode instances plus the proxy and executes the accuracy test.
  • (various __init__.py files): Package scaffolding for the new connector modules and tests.
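The bind_kv_cache_name helper mentioned in the file summary can be imagined roughly as follows. This is a hedged sketch: the naming scheme ("<layer>.key"/"<layer>.value") and the function signature are assumptions for illustration, not the actual vllm_rbln implementation:

```python
# Hypothetical sketch of a per-layer KV-cache naming helper in the
# spirit of bind_kv_cache_name: give each KV buffer a stable name so it
# can later be referenced by mark_static_address-style compile hooks.
def bind_kv_cache_name(layer_names, kv_caches):
    """Return {buffer_name: tensor} pairing each layer's key/value
    cache with a stable, per-layer name."""
    named = {}
    for layer, (key_cache, value_cache) in zip(layer_names, kv_caches):
        named[f"{layer}.key"] = key_cache      # assumed naming scheme
        named[f"{layer}.value"] = value_cache  # assumed naming scheme
    return named
```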


Comment on lines 145 to 152
    def sleep(self, level: int = 1) -> None:
        logger.warning("sleep mode is not supported on RBLN, ignore it.")

    def wake_up(self, tags: list[str] | None = None) -> None:
        logger.warning("sleep mode is not supported on RBLN, ignore it.")


Copilot AI Apr 15, 2026

initialize_cache() was removed, but it’s still used by existing unit tests (e.g. tests/torch_compile/unit/v1/worker/test_rbln_worker.py::TestInitializeCache). This will break the worker interface expected by tests (and likely callers). Please restore initialize_cache() (or update callers/tests consistently) to keep cache_config.num_gpu_blocks/num_cpu_blocks configurable.
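A minimal restoration along the lines the reviewer asks for might look like this. The signature follows the vLLM worker interface; the surrounding class and cache_config wiring are assumptions for illustration:

```python
# Hedged sketch of restoring initialize_cache() with the signature the
# worker interface (and the unit test named above) expects.
from types import SimpleNamespace

class RblnWorkerSketch:
    def __init__(self, cache_config):
        self.cache_config = cache_config

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        # Record block counts so the engine can size the KV cache.
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

worker = RblnWorkerSketch(cache_config=SimpleNamespace())
worker.initialize_cache(num_gpu_blocks=128, num_cpu_blocks=0)
```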

Comment on lines +4496 to +4506
    if direction == "h2d":
        kv_caches = src_kv_caches
        copy_fn = runtime._update_kv_cache
    else:
        kv_caches = dst_kv_caches
        copy_fn = runtime._fetch_kv_cache

    for idx in src_block_ids:
        for kv_name, kv_cache in kv_caches.items():
            block_size = kv_cache.shape[-2]
            copy_fn(kv_cache.data_ptr(), idx, 0, block_size, kv_name)

Copilot AI Apr 15, 2026


rbln_copy_kv_blocks ignores dst_block_ids and always uses idx from src_block_ids when calling the runtime copy function. This breaks the expected src↔dst mapping when blocks need to be copied to different indices (e.g., compaction/relocation), and it also means dst_kv_caches is never used for the h2d case. Consider iterating with for src_id, dst_id in zip(src_block_ids, dst_block_ids) and passing the correct destination block id (and source/destination buffer) to the runtime API.

Suggested change

Replace:

    if direction == "h2d":
        kv_caches = src_kv_caches
        copy_fn = runtime._update_kv_cache
    else:
        kv_caches = dst_kv_caches
        copy_fn = runtime._fetch_kv_cache

    for idx in src_block_ids:
        for kv_name, kv_cache in kv_caches.items():
            block_size = kv_cache.shape[-2]
            copy_fn(kv_cache.data_ptr(), idx, 0, block_size, kv_name)

with:

    for src_id, dst_id in zip(src_block_ids, dst_block_ids):
        if direction == "h2d":
            kv_caches = src_kv_caches
            block_id = dst_id
            copy_fn = runtime._update_kv_cache
        else:
            kv_caches = dst_kv_caches
            block_id = src_id
            copy_fn = runtime._fetch_kv_cache
        for kv_name, kv_cache in kv_caches.items():
            block_size = kv_cache.shape[-2]
            copy_fn(
                kv_cache.data_ptr(), block_id, 0, block_size, kv_name
            )

Comment on lines +194 to +196

    response = await client_info["client"].post(
        endpoint, json=req_data, headers=headers

Copilot AI Apr 15, 2026


httpx.AsyncClient is configured with base_url=.../v1, but requests are made with endpoints like /completions and /chat/completions (leading slash). In httpx, a leading slash treats the request URL as absolute-path and will drop the /v1 prefix from base_url, so this proxy may call /completions on the backend instead of /v1/completions. Use relative paths (e.g., completions) or remove /v1 from base_url and keep /v1/... in the endpoints consistently (also applies to stream_service_response).

Suggested change

Replace:

    response = await client_info["client"].post(
        endpoint, json=req_data, headers=headers

with:

    normalized_endpoint = endpoint.lstrip("/")
    response = await client_info["client"].post(
        normalized_endpoint, json=req_data, headers=headers
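The URL-merging behavior the reviewer describes can be demonstrated without httpx: httpx joins request paths against base_url with RFC 3986 relative-resolution semantics, which urllib.parse.urljoin reproduces when the base path ends in "/" (httpx normalizes base_url that way). The backend host below is a placeholder:

```python
# Demonstrate why a leading slash drops the /v1 prefix from base_url.
from urllib.parse import urljoin

base = "http://backend:8000/v1/"  # httpx normalizes base_url to end in "/"

# Absolute path: replaces the base path entirely, losing /v1.
print(urljoin(base, "/completions"))   # http://backend:8000/completions

# Relative path: merged under the base path, keeping /v1.
print(urljoin(base, "completions"))    # http://backend:8000/v1/completions
```

This is why either the endpoints should be relative ("completions") or base_url should omit /v1 while the endpoints keep it.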

Comment thread on tests/torch_compile/e2e/v1/kv_connector/nixl_integration/run_accuracy_test.sh (outdated)
rebel-ykchoi and others added 2 commits April 16, 2026 10:29

Labels

torch.compile torch.compile based implementation

3 participants